import os
import torch

def load_checkpoint(model,
                    load_dir,
                    tag,
                    load_module_strict=True,
                    load_optimizer_states=True,
                    load_lr_scheduler_states=True):
    r"""Load training checkpoint

    Arguments:
        load_dir: Required. Directory to load the checkpoint from
        tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
        load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and checkpoint match.
        load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. Ex. ADAM's momentum and variance
        load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint.
    Return:
        load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
        client_state: State dictionary used for loading required training states in the client code.
    """

    load_path, client_states = _load_checkpoint(model,
                                                load_dir,
                                                tag,
                                                load_module_strict=load_module_strict,
                                                load_optimizer_states=load_optimizer_states,
                                                load_lr_scheduler_states=load_lr_scheduler_states)

    if load_optimizer_states and model.zero_optimization() and load_path is not None:
        # ZeRO partitions optimizer state across data-parallel ranks, so its
        # optimizer state is restored from separate per-rank checkpoint files.
        model._load_zero_checkpoint(load_dir,
                                    tag,
                                    load_optimizer_states=load_optimizer_states)

    return load_path, client_states
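
# A minimal usage sketch (hypothetical client code; assumes `model` is a
# DeepSpeed-style engine and that './checkpoints/1000/' holds a checkpoint):
#
#   load_path, client_state = load_checkpoint(model, './checkpoints', tag='1000')
#   if load_path is None:
#       start_step = 0  # no checkpoint found, train from scratch
#   else:
#       # 'step' is a hypothetical key the client saved alongside the engine
#       # state; it comes back untouched in client_state.
#       start_step = client_state.get('step', 0)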

def _get_ckpt_name(mpu, checkpoints_path, tag):
    # Checkpoint layout: <checkpoints_path>/<tag>/mp_rank_XX_model_states.pt,
    # one file per model-parallel rank (rank 0 when there is no model parallelism).
    mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
    ckpt_name = os.path.join(checkpoints_path,
                             str(tag),
                             'mp_rank_{:02d}_model_states.pt'.format(mp_rank))
    return ckpt_name
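
# For example (following the logic above), with no model parallelism
# (mpu is None):
#
#   _get_ckpt_name(None, './checkpoints', 100)
#   # -> './checkpoints/100/mp_rank_00_model_states.pt'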

def pre_load(mpu,
             load_dir,
             tag):
    """Load only the module (model weight) state dict from a checkpoint,
    mapped to CPU, without touching optimizer or scheduler state."""
    load_path = _get_ckpt_name(mpu, load_dir, tag)
    # map_location keeps tensors on CPU rather than the device they were saved from
    checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
    return checkpoint['module']
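
# Sketch (an assumption about intended use, not documented above): pre_load
# can seed a fresh model with pretrained weights before training starts, e.g.
#
#   net.load_state_dict(pre_load(mpu=None, load_dir='./checkpoints', tag='1000'))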

def _load_checkpoint(model,
                     load_dir,
                     tag,
                     load_module_strict=True,
                     load_optimizer_states=True,
                     load_lr_scheduler_states=True):

    load_path = model._get_ckpt_name(load_dir, tag)

    if not os.path.exists(load_path):
        return None, None

    checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)

    model.load_module_state_dict(state_dict=checkpoint['module'],
                                 strict=load_module_strict)
    if load_optimizer_states and not model.zero_optimization():
        if model.fp16_enabled():
            # The FP16 optimizer wrapper takes the flag explicitly so it can
            # restore master weights separately from optimizer state.
            model.optimizer.load_state_dict(
                checkpoint['optimizer'],
                load_optimizer_states=load_optimizer_states)
        else:
            model.optimizer.load_state_dict(checkpoint['optimizer'])

    if load_lr_scheduler_states and model.lr_scheduler is not None:
        model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    model.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
    model.global_steps = checkpoint['global_steps']
    # Older checkpoints may lack 'global_samples'; fall back to an estimate
    # derived from the step count and the current train batch size.
    model.global_samples = checkpoint.get('global_samples',
                                          model.global_steps * model.train_batch_size())
    model.skipped_steps = checkpoint['skipped_steps']
    model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
    model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
    deepspeed_states = [
        'module',
        'optimizer',
        'lr_scheduler',
        'csr_tensor_module_names',
        'skipped_steps',
        'global_steps',
        'dp_world_size',
        'mp_world_size'
    ]
    # Everything not owned by the engine is handed back to the caller.
    client_state = {
        key: value
        for key, value in checkpoint.items() if key not in deepspeed_states
    }

    return load_path, client_state
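
# For illustration: if the checkpoint also contained client-saved entries such
# as {'step': 500, 'best_eval_loss': 1.23} (hypothetical keys), they pass
# through to client_state unchanged, while the engine-owned keys listed in
# deepspeed_states are filtered out.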
